from torch.utils.data import Dataset
from datasets import load_from_disk

class MyDataset(Dataset):
    """Map-style dataset exposing one split of a corpus saved on disk.

    The corpus is read with ``datasets.load_from_disk`` and indexed by split
    name; each row is then accessed through its ``'text'`` and ``'label'``
    fields. The default location points at a local ChnSentiCorp copy.
    """

    # Default on-disk location of the saved corpus (kept for existing callers).
    DEFAULT_DATA_DIR = 'D:/AI/HuggingFace/data/ChnSentiCorp/chn_senti_corp'
    # Split names accepted by the constructor.
    SPLITS = ('train', 'test', 'validation')

    def __init__(self, split, data_dir=DEFAULT_DATA_DIR):
        """Load the requested split.

        Args:
            split: One of ``'train'``, ``'test'`` or ``'validation'``.
            data_dir: Directory previously written by ``save_to_disk``.
                Defaults to ``DEFAULT_DATA_DIR``.

        Raises:
            ValueError: If ``split`` is not one of the accepted names.
        """
        # Validate before touching the disk so a mistyped split fails
        # immediately instead of after the whole corpus has been loaded.
        if split not in self.SPLITS:
            raise ValueError('数据集名称错误！请核对加载的数据集')
        # Select the requested partition (train / test / validation).
        self.dataset = load_from_disk(data_dir)[split]

    def __len__(self):
        """Return the number of rows in the selected split."""
        return len(self.dataset)

    def __getitem__(self, item):
        """Return the ``(text, label)`` pair stored at ``item``."""
        # Fetch the row once; indexing the underlying dataset twice would
        # repeat the row lookup for every sample.
        row = self.dataset[item]
        return row['text'], row['label']
""" if __name__ == '__main__':
    dataset = Mydataset("test")
    for data in dataset:
        print(data)
        break """